2016ICPC Dalian G 点分治

点分治+枚举一个数二进制的子集(和2018CCPC秦皇岛枚举子集一样)


题意

给一颗点染色树,问有多少对 $(u,v)$,使得在此路径上的种类数有 $k$ 个。

分析

  • 显然点分治题
  • 重点在于统计数量的复杂度。
  • 二进制枚举一个数的子集

    1
    for(int  j=i; j; j=(i-1)&i)
  • 但 $vector$ 初始化一直 $RE$,调了十年(毒瘤UVALive)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstring>
using namespace std;
#define ll long long
const int maxn = 500000+10;

vector<int>g[maxn];
int n, k, u, v, col[maxn];
int d[maxn], cnt, sz[maxn], maxsz[maxn], minx, sum, minid;
ll ans;
bool vis[maxn];
ll state[1200];

void init()
{
ans=cnt=0;
memset(vis, false, sizeof(vis));
// for(int i=1;i<=n;i++) g[i].clear();

}

void dfs1(int u,int fa)
{
// cout<<"--"<<5<<endl;
maxsz[u]=0; sz[u]=1;
int len=g[u].size();
for(int i=0;i<len;i++)
{
int v=g[u][i];
if(vis[v] || v==fa) continue;
dfs1(v, u);
sz[u]+=sz[v];
maxsz[u] = max(maxsz[u], sz[v]);
}
}

void dfs2(int u, int fa)
{
// cout<<"--"<<4<<endl;
int temp=max(maxsz[u], sum-maxsz[u]);
if(temp < minx)
{
minx=temp;
minid=u;
}
int len=g[u].size();
for(int i=0;i<len;i++)
{
int v=g[u][i];
if(v==fa||vis[v] ) continue;
dfs2(v, u);
}

}

void getcol(int u,int fa, int dis)
{

d[++cnt]=dis;
int len=g[u].size();
//cout<<"--"<<3<<' '<<len<<endl;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
//cout<<v<<' '<<vis[v]<<endl;
if(v==fa || vis[v]) continue;
// cout<<"--"<<3<<' '<<v<<endl;
getcol(v, u, (dis|(1<<col[v])));
}
}

int getroot(int u)
{
// cout<<"--"<<2<<endl;
dfs1(u, 0);
minx=n;
minid=-1;
sum=sz[u];
dfs2(u, 0);
return minid;
}

ll calc(int u, int val)
{
ll aans=0;
cnt=0;
getcol(u, 0, val);
for(int i=0;i<(1<<k);i++) state[i]=0;
for(int i=1;i<=cnt;i++) state[d[i]]++;
for(int i=1;i<=cnt;i++) {
state[d[i]]--;
aans += state[(1<<k)-1];
for(int j=d[i]; j ; j=(j-1)&d[i])
aans += state[ ((1<<k)-1) ^ j];
state[d[i]]++;
}
return aans;
}

void solve(int u)
{
int root=getroot(u);
ans += calc(root, (1<<col[root]));
vis[root]=1;
int len=g[root].size();
for(int i=0;i<len;i++)
{
int v=g[root][i];
if(vis[v]) continue;
ans -= calc(v,(1<<col[root])|(1<<col[v]));
solve(v);
}
}

int main()
{
while(scanf("%d%d", &n, &k)!=EOF)
{

if(n==0&&k==0) break;
init();
for(int i=1;i<=n;i++)
{
scanf("%d", &col[i]);
--col[i];
}
for(int i=1;i<=n-1;i++)
{
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
if(k==1) {
ll nn=n;
printf("%lld\n", nn*nn);
continue;
}
// cout<<123<<endl;
solve(1);
for(int i=1; i<=n; i++) g[i].clear();
printf("%lld\n" , ans);

}
return 0;
}